Pytorch

您所在的位置:网站首页 卷积 边缘检测 Pytorch

Pytorch

#Pytorch| 来源: 网络整理| 查看: 265

题目:对任意图片进行简单卷积操作,并提取图片的边缘信息

文章目录 小练习~卷积实例小练习~边缘特征提取实例关于边缘检测的原理Sobel滤波器Scharr滤波器

本次使用图片:在这里插入图片描述

小练习~卷积实例

首先将图片读取进来->转为灰度图->转为numpy数组:

# 读图 image = Image.open('../data/cat.png') # 转为灰度图 image = image.convert("L") # 转为numpy数组 image_np = np.array(image)

此时得到的numpy数组维度为二维,可以使用shape输出图片大小(719,719) 。

将 (719,719) 变成(11719*719)的张量,方便后面卷积操作。

h , w = image_np.shape image_tensor = torch.from_numpy(image_np.reshape(1,1,h,w)).float() # torch.from_numpy ()方法把数组转换成张量,

此时进行shape输出得到四维张量 torch.Size([1, 1, 719, 719])。

设置卷积核的大小,定义卷积层。

kersize = 5 # 5*5的卷积核大小 ker = torch.ones(kersize, kersize, dtype=torch.float32) * -1 # 元素全部设置为 -1 conv2d = torch.nn.Conv2d(1, 1, (kersize, kersize), bias=False) # 设置卷积网络 ker = ker.reshape((1, 1, kersize, kersize))# 将ker变成四维张量 conv2d.weight.data = ker # 初始化权重

将数据放进去卷积,然后压缩一下维度: 这里的squeeze()函数就是压缩维度,可以去掉其中为1的维度。

image_out = conv2d(image_tensor) image_out = image_out.data.squeeze() 去维度之前和之后: torch.Size([1, 1, 715, 715]) torch.Size([715, 715])

这样的二维数据就可以拿去画图了:

plt.axis('off') # 不显示坐标轴 plt.imshow(image_out, cmap=plt.cm.gray) # cmap表示给图上黑白色 plt.show()

完整代码:

## 普通卷积 ## 普通卷积 from PIL import Image import torch import matplotlib.pyplot as plt import numpy as np from torch import nn image = Image.open('../data/cat.png') image = image.convert("L") image_np = np.array(image) h, w = image_np.shape image_tensor = torch.from_numpy(image_np.reshape(1, 1, h, w)).float() kersize = 5 ker = torch.ones(kersize, kersize, dtype=torch.float32) * -1 print(ker) conv2d = torch.nn.Conv2d(1, 1, (kersize, kersize), bias=False) ker = ker.reshape((1, 1, kersize, kersize)) conv2d.weight.data = ker image_out = conv2d(image_tensor) image_out = image_out.data.squeeze() plt.axis('off') plt.imshow(image_out, cmap=plt.cm.gray) plt.show()

可以看到结果: 在这里插入图片描述 …额 有点吓人。

小练习~边缘特征提取实例

在刚才的程序中,我们使用了卷积核维度5*5。

现在用于边缘检测,我们初始化用于边缘检测的卷积核。 只需要加一行代码

ker[2, 2] = 24 关于边缘检测的原理

其实就是设置了特殊的卷积核(滤波器)达到了检测边缘的目的。

在这里插入图片描述 可以看一下上面这张图, 设置的卷积核(滤波器)设计的非常巧妙 一列1 一列0 一列-1(称为垂直边缘检测)。 这就导致颜色相同的区域与这个卷积核做卷积得到的数据近0。而越接近0颜色就越浅。 区域内颜色相差越大,得到的数字也就越大,颜色就会越深。 最后就会描绘出图像边缘。 一行1 一行 0 一行-1 就是水平边缘检测。

Sobel滤波器

,其增加了中间一行的权重,加强处理图像中央的元素点。 在这里插入图片描述

Scharr滤波器

。也是一种垂直边缘检测,反转90度就变成了水平边缘检测。 在这里插入图片描述

本题中用到的 全是-1 仅仅中间是24这个情况。。没懂。不过查阅资料发现有点类似于 拉普拉斯算子卷积核的情况。 先不管这个为什么设置中间为24了 ,后面懂了再回来补吧。

另外: 我们也可以直接将普通卷积操作和边缘检测卷积操作两个合并到一起,在卷积层设置的时候输出通道数改成2,然后再绘图的时候绘制到一起,看下面代码里的注释。

完整代码:

from PIL import Image import torch import matplotlib.pyplot as plt import numpy as np # 在Image转Tensor过程中,图片的格式会由: H * W * C的格式转为: C * H * W格式。 from torch import nn # 读图 image = Image.open('../data/cat.png') # 转为灰度图 image = image.convert("L") # 转为numpy数组 image_np = np.array(image) # plt.figure(figsize=(6,6)) # 绘图并指定长和高 # # 参数cmap将标量数据映射到图片 cmap = plt.cm.gray 返回线性灰度色图 # # cmap参数接受一个值(每个值代表一种配色方案),并将该值对应的颜色图分配给当前图窗。 # # 如果将当前图窗比作一幅简笔画,则cmap就代表颜料盘的配色,用所提供的颜料盘自动给当前简笔画上色,就是cmap所做的事。 # plt.imshow(image_np,cmap=plt.cm.gray) # plt.axis("off") # 不显示坐标轴 # plt.show() # print(image_np.shape) # 得到 结果 (719, 719) # 将上述数组转换成 1*1*719*719的张量 h , w = image_np.shape image_tensor = torch.from_numpy(image_np.reshape(1,1,h,w)).float() # torch.from_numpy ()方法把数组转换成张量, # print(image_tensor.shape) # 得到torch.Size([1, 1, 719, 719]) kersize = 5 ker = torch.ones(kersize,kersize,dtype=torch.float32)*-1 ker[2,2] = 24 ker = ker.reshape((1,1,kersize,kersize)) # 边缘卷积核 conv2d = torch.nn.Conv2d(1,2,(kersize,kersize),bias=False) # 设置卷积核 # print(ker) # print(conv2d.weight.data.shape) # 权重维度 2*1*5*5 输出通道数,输入通道数,卷积核长宽 conv2d.weight.data[0] = ker print(conv2d.weight.data) # 第一个维度的权重用于进行边缘提取,第二个卷积核是随机数 image_out = conv2d(image_tensor) print(image_out.shape) image_out = image_out.data.squeeze() # 数组的维度压缩,去掉其中为1的维度 print(image_out.shape) # 要注意此时的 image_out里面有两个 715*715的图 torch.Size([2, 715, 715]) 三维张量 print(image_out) plt.figure(figsize=(18,6)) # 表示输出的长和宽 plt.subplot(1,3,1) # 1行 2列 然后按照从左到右,从上到下的顺序对每个子区域进行编号,左上的子区域的编号为1,最后那个参数指定显示的区域编号 plt.imshow(image_out[0],cmap=plt.cm.gray) # 拿出来第0个 就是秒轮廓的那个 plt.axis('off') plt.subplot(1,3,2) # 写第二个图 plt.imshow(image_out[1],cmap=plt.cm.gray) plt.axis('off') plt.subplot(1,3,3) plt.imshow(image_np,cmap=plt.cm.gray) plt.axis('off') plt.show()


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3